package org.example.utils

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, dense_rank}

import scala.collection.mutable.ListBuffer

object VertexUtil {

    /**
     * Maps String vertex ids to dense integer ids (required by graph
     * algorithms that expect numeric vertex identifiers).
     *
     * The id mapping is persisted as a CSV under `mappingPath` so that
     * [[reconvertLongId2StringId]] can translate the algorithm's result back
     * to the original String ids.
     *
     * @param dataframe   edge dataframe with columns "src", "dst", "weight"
     * @param list        message buffer; a vertex-count summary is appended
     * @param mappingPath where the id -> encodedId mapping CSV is written
     *                    (defaults to the historical hard-coded location,
     *                    so existing callers are unaffected)
     * @return edge dataframe with columns ("src", "dst", "weight") where
     *         src/dst carry the encoded integer ids
     */
    def convertStringId2LongId(
            dataframe: DataFrame,
            list: ListBuffer[String],
            mappingPath: String = "/tmp/encodeId1"): DataFrame = {
        // Collect the distinct vertex ids from both endpoints of every edge.
        val srcIdDF: DataFrame = dataframe.select("src").withColumnRenamed("src", "id")
        val dstIdDF: DataFrame = dataframe.select("dst").withColumnRenamed("dst", "id")
        val idDF               = srcIdDF.union(dstIdDF).distinct()

        // Encode each id to a dense integer with dense_rank. NOTE(review):
        // the unpartitioned window funnels every id through one partition —
        // fine for moderate graphs, a scalability bottleneck for huge ones.
        // Cached because it feeds the write, the count, and two joins below.
        val encodeId = idDF
                .withColumn("encodedId", dense_rank().over(Window.orderBy("id")))
                .cache()

        // Persist the mapping so the result can be converted back later.
        encodeId.write.mode("overwrite").option("header", true).csv(mappingPath)

        // Count once (previously computed twice, re-running the whole plan).
        val vertexCount = encodeId.count()
        list.append("共" + vertexCount + "个节点")
        println("==========================共" + vertexCount + "个节点==========================")

        // Replace src with its encoded id. An explicit equi-join condition
        // replaces the former unconditioned join + where, which produced a
        // Cartesian product before filtering (and can fail on Spark builds
        // where cross joins are disabled by default).
        val srcJoinDF = dataframe
                .join(encodeId, col("src") === col("id"))
                .drop("src")
                .drop("id")
                .withColumnRenamed("encodedId", "src")
        srcJoinDF.cache()

        // Replace dst with its encoded id the same way.
        val dstJoinDF = srcJoinDF
                .join(encodeId, col("dst") === col("id"))
                .drop("dst")
                .drop("id")
                .withColumnRenamed("encodedId", "dst")

        // Keep the expected column order: src, dst, weight.
        dstJoinDF.select("src", "dst", "weight")
    }

    /**
     * Translates encoded integer vertex ids back to the original String ids
     * using the mapping written by [[convertStringId2LongId]].
     *
     * @param spark       active session used to read the mapping CSV
     * @param dataframe   algorithm result with an encoded id column "_id"
     * @param mappingPath location of the id -> encodedId mapping CSV
     *                    (must match the path used when encoding)
     * @return `dataframe` joined with the mapping, with "_id"/"encodedId"
     *         dropped so only the original String "id" remains
     */
    def reconvertLongId2StringId(
            spark: SparkSession,
            dataframe: DataFrame,
            mappingPath: String = "/tmp/encodeId1"): DataFrame = {
        // Read the persisted mapping back with an explicit schema so
        // encodedId is typed INT instead of being inferred as STRING.
        val schema   = "id STRING, encodedId INT" // val: never reassigned (was var)
        val encodeId = spark.read.option("header", true).schema(schema).csv(mappingPath)

        // Explicit equi-join (previously a cross join narrowed by a filter).
        encodeId
                .join(dataframe, col("encodedId") === col("_id"))
                .drop("encodedId")
                .drop("_id")
    }

}
